import sys
from sklearn.linear_model import LogisticRegression
import pandas as pd
from utils import get_dataset
from models.LRBinsModel import LRBinsModel
import numpy as np
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score, accuracy_score
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
import argparse

# Command-line interface: a positional dataset name and an optional switch
# that enables writing the plots at the end of the script.
parser = argparse.ArgumentParser()
parser.add_argument(
    'datasetname',
    type=str,
    help='the name of the dataset to process',
)
parser.add_argument(
    '--make_plots',
    action='store_true',
    help='if specified, create plots of the data',
)
args = parser.parse_args()
datasetname, make_plots = args.datasetname, args.make_plots

# Both inputs are looked up by dataset name, relative to the working directory.
data = pd.read_csv(f"data/{datasetname}.csv")
# NOTE(review): unpickling can execute arbitrary code; only point this script
# at hyperparameter files that come from a trusted source.
with open(f"hyperparameters/{datasetname}.p", "rb") as fp:
    hyperparameters = pickle.load(fp)

# get_dataset returns seven values; going by the names they are bound to,
# these are the train/validation/test features and labels plus feature names.
splits = get_dataset(data, normalize=True, random_state=123)
X_train, X_val, X_test, y_train, y_val, y_test, feature_names = splits

# Number of columns in the training feature matrix.
num_features = X_train.shape[1]

# Baseline: a standalone XGBoost classifier, sized by the stored
# hyperparameters, fitted on the training split and scored on the test split.
clf = XGBClassifier(
    max_depth=hyperparameters["xgb_max_depth"],
    n_estimators=hyperparameters["xgb_n_estimators"],
)
clf.fit(X_train, y_train)

# Column 1 of predict_proba is used as the score for ROC AUC; the hard
# predictions and the test labels are cast to int before computing accuracy.
y_probs = clf.predict_proba(X_test)[:, 1]
y_preds = clf.predict(X_test).astype(int)
y_test = y_test.astype(int)

roc_score = roc_auc_score(y_test, y_probs)
acc_score = accuracy_score(y_test, y_preds)

# Aliases of the baseline scores, used when printing the xgb-vs-hybrid gaps.
curpaperxgbrocauc, curpaperxgbacc = roc_score, acc_score

# Hybrid model: an LRBinsModel that is given the fitted XGBoost classifier as
# its fallback_model. This first fit uses no explicit first_stage_threshold
# and is run to obtain the candidate thresholds swept afterwards.
# NOTE(review): the evaluation data passed to fit() is the training split
# itself; X_val / y_val are not used here -- confirm this is intended.
X_eval, y_eval = X_train, y_train

lrbins_kwargs = dict(
    inference_on_all_bins=False,
    fallback_model=clf,
    n_bin_features=hyperparameters["lrbins_n_bin_features"],
    n_inference_features=hyperparameters["lrbins_n_inference_features"],
    sort_with_metric="accuracy",
)
model = LRBinsModel(**lrbins_kwargs)
model.fit(X_train, y_train, X_eval=X_eval, y_eval=y_eval)

# performance() is called before reading model.thresholds, as in the original
# statement order.
results = model.performance(X_test, y_test)
thresholds = model.thresholds

# Sweep over the candidate thresholds: a fresh hybrid model is fitted for each
# one and its test-set coverage, ROC AUC and accuracy are recorded.
# Each series starts with the pure-XGBoost point (coverage 0.0 and the
# baseline scores), so the sweep results are appended after it.
coverages = [0.0]
rocaucs = [roc_score]
accuracies = [acc_score]

# Constructor arguments that are identical for every threshold.
sweep_kwargs = dict(
    inference_on_all_bins=False,
    fallback_model=clf,
    n_bin_features=hyperparameters["lrbins_n_bin_features"],
    n_inference_features=hyperparameters["lrbins_n_inference_features"],
    sort_with_metric="accuracy",
)

for threshold in thresholds:
    model = LRBinsModel(first_stage_threshold=threshold, **sweep_kwargs)
    model.fit(X_train, y_train, X_eval=X_eval, y_eval=y_eval)
    results = model.performance(X_test, y_test)

    rocaucs.append(results["rocauc"])
    coverages.append(results["coverage"])
    accuracies.append(results["accuracy"])

    # Report the XGBoost-minus-hybrid gap for each metric, and the coverage
    # scaled by 100.
    print("xgb-hybrid rocauc:", curpaperxgbrocauc-results["rocauc"])
    print("xgb-hybrid acc:", curpaperxgbacc-results["accuracy"])
    print("coverage:", results["coverage"]*100)
    print()

# Optional plots: hybrid-model ROC AUC and accuracy against coverage, each
# drawn together with a flat line at the standalone XGBoost score.
if make_plots:
    import os

    # savefig does not create missing directories; without this the script
    # fails with FileNotFoundError after all the model fitting is done.
    os.makedirs("output", exist_ok=True)

    # (y-axis label, hybrid scores per coverage, XGBoost baseline, file stem)
    plot_specs = (
        ("ROCAUC", rocaucs, roc_score, "rocauc"),
        ("Accuracy", accuracies, acc_score, "accuracy"),
    )
    for ylabel, hybrid_scores, xgb_score, stem in plot_specs:
        plt.plot(coverages, hybrid_scores, label="hybrid lrwbins/xgb")
        # Constant reference line at the XGBoost-only score.
        plt.plot(coverages, [xgb_score] * len(coverages), label="xgb")
        plt.xlabel("Coverage")
        plt.ylabel(ylabel)
        plt.legend()
        plt.savefig(f"output/{stem}_{datasetname}.png")
        # Close the figure so the next plot starts from a clean canvas.
        plt.close()